iT邦幫忙

2023 iThome 鐵人賽

DAY 27
0
AI & Data

機器學習不難嘛系列 第 27

Day27-線性回歸 調整預測線和成本函數

  • 分享至 

  • xImage
  •  

上一篇中我們把資料寫進了x跟y兩個變數中,我們要把這些資料用圖片表示出來。

https://ithelp.ithome.com.tw/upload/images/20231010/20162311fLPDeArK8j.png

接下來要畫出我們的預測線,我會寫一個自定義函數來印出預測線,在用到這個函數時要寫入兩個參數,分別是w和b,他們代表著我們要印出的預測線的斜率和起點高度,如下:

def plot_pred(w, b):
  y_pred = x*w + b
	plt.plot(x, y_pred, color="blue", label="predict_line")
  plt.scatter(x, y, marker="x", color="red", label="real_data")
  plt.title("Height - Weight")
  plt.xlabel("Height(m)")
  plt.ylabel("Weight(kg)")
  plt.xlim([1, 3])
  plt.ylim([40, 100])
  plt.legend()
  plt.show()

plot_pred(0, 50)#參數可以自己調整看看

https://ithelp.ithome.com.tw/upload/images/20231010/20162311JWAgHxOrEJ.png

如果想在圖中調整參數也可以用下面這個試試看,會出現兩個可調整的數字,更方便觀察。

from ipywidgets import interact

interact(plot_pred, w=(-100, 100, 1), b=(0, 100, 1))#w跟b的數值也可以隨意調整

https://ithelp.ithome.com.tw/upload/images/20231010/20162311VPhJMxtdHg.png
學會印出圖片和預測線後我們來講一下成本函數,成本函數是用來判斷我們所畫出的預測線的好壞,舉個例子:

https://ithelp.ithome.com.tw/upload/images/20231010/20162311Z21KWNQr6b.png

https://ithelp.ithome.com.tw/upload/images/20231010/20162311KnlrQ0w2zG.png

上面這兩張圖很明顯看得出誰的預測線比較準對吧,那是因為上面那張圖的成本函數比下面的圖少很多,成本函數是怎麼計算的呢?

線性回歸的成本函數計算可以想成每筆資料到預測線的距離加總(a1-a2)²+(b1-b2)²+(c1-c2)²,上面的圖的成本函數就是(2-2)²+(4-4)²+(6-6)²=0,下面的成本函數為(1-2)²+(2-4)²+(3-6)²=14

回到我們的範例中,模型的成本函數可以用下列程式碼進行計算。

w = 0
b = 0#w跟b可自行調整
y_pred = w*x + b
cost = (y - y_pred)**2
cost.sum() / len(x)

上一篇
Day26-線性回歸 介紹
下一篇
Day28-線性回歸 暴力破解
系列文
機器學習不難嘛30
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言